{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c11fde21",
   "metadata": {
    "id": "c11fde21"
   },
   "source": [
    "# Multimodal search using CLIP\n",
    "\n",
    "![mmclip](https://www.researchgate.net/publication/363808556/figure/fig2/AS:11431281086053770@1664048343869/Architectures-of-the-designed-machine-learning-approaches-with-OpenAI-CLIP-model.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06c53ccb-5654-4150-93a2-cdf9ff4d8d26",
   "metadata": {
    "id": "06c53ccb-5654-4150-93a2-cdf9ff4d8d26"
   },
   "source": [
    "### Installing all dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "69fb1627",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "69fb1627",
    "outputId": "7516efba-c93a-4201-ad7a-0d3c4d82d0c1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pip in /usr/local/lib/python3.10/dist-packages (23.1.2)\n",
      "Requirement already satisfied: install in /usr/local/lib/python3.10/dist-packages (1.3.5)\n",
      "Requirement already satisfied: tantivy==0.20.1 in /usr/local/lib/python3.10/dist-packages (0.20.1)\n"
     ]
    }
   ],
   "source": [
    "# Install the notebook's dependencies. lancedb is upgraded to the latest release;\n",
    "# gradio and tantivy are pinned; transformers, torch, torchvision and duckdb are unpinned.\n",
    "!pip install --quiet -U lancedb\n",
    "!pip install --quiet gradio==3.41.2  transformers torch torchvision duckdb\n",
    "# NOTE(review): tantivy is presumably the backend for the full-text-search index created later - confirm.\n",
    "!pip install tantivy==0.20.1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d53ade3",
   "metadata": {
    "id": "2d53ade3"
   },
   "source": [
    "## First run setup: Download data and pre-process\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9b7e97f9",
   "metadata": {
    "id": "9b7e97f9"
   },
   "outputs": [],
   "source": [
    "import io\n",
    "import PIL\n",
    "import duckdb\n",
    "import lancedb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5ba75742",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5ba75742",
    "outputId": "c97f0590-3839-4f67-cb5f-60295142f99b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-02-25 04:51:11--  https://eto-public.s3.us-west-2.amazonaws.com/datasets/diffusiondb_lance.tar.gz\n",
      "Resolving eto-public.s3.us-west-2.amazonaws.com (eto-public.s3.us-west-2.amazonaws.com)... 3.5.79.102, 3.5.82.217, 3.5.77.120, ...\n",
      "Connecting to eto-public.s3.us-west-2.amazonaws.com (eto-public.s3.us-west-2.amazonaws.com)|3.5.79.102|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 6121107645 (5.7G) [application/x-gzip]\n",
      "Saving to: ‘diffusiondb_lance.tar.gz’\n",
      "\n",
      "diffusiondb_lance.t 100%[===================>]   5.70G  69.8MB/s    in 88s     \n",
      "\n",
      "2024-02-25 04:52:39 (66.6 MB/s) - ‘diffusiondb_lance.tar.gz’ saved [6121107645/6121107645]\n",
      "\n",
      "diffusiondb_test/\n",
      "diffusiondb_test/_versions/\n",
      "diffusiondb_test/_latest.manifest\n",
      "diffusiondb_test/data/\n",
      "diffusiondb_test/data/138fc0d8-a806-4b10-84f8-00dc381afdad.lance\n",
      "diffusiondb_test/_versions/1.manifest\n"
     ]
    }
   ],
   "source": [
    "# Download the DiffusionDB Lance archive (about 5.7 GB), unpack it, and rename the\n",
    "# extracted directory to rawdata.lance, the path read when the table is created.\n",
    "# NOTE(review): nothing here guards against re-runs, so the download repeats each time.\n",
    "!wget https://eto-public.s3.us-west-2.amazonaws.com/datasets/diffusiondb_lance.tar.gz\n",
    "!tar -xvf diffusiondb_lance.tar.gz\n",
    "!mv diffusiondb_test rawdata.lance"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2fcbf61",
   "metadata": {
    "id": "e2fcbf61"
   },
   "source": [
    "## Create / Open LanceDB Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b3317a3c",
   "metadata": {
    "id": "b3317a3c"
   },
   "outputs": [],
   "source": [
    "import pyarrow.compute as pc\n",
    "import lance\n",
    "\n",
    "# Open the LanceDB database stored under ~/datasets/demo.\n",
    "db = lancedb.connect(\"~/datasets/demo\")\n",
    "if \"diffusiondb\" in db.table_names():\n",
    "    # The table already exists from an earlier run: reuse it.\n",
    "    tbl = db.open_table(\"diffusiondb\")\n",
    "else:\n",
    "    # First run: load the raw Lance dataset, create the table from it, and\n",
    "    # build a full-text-search index on the prompt column.\n",
    "    data = lance.dataset(\"rawdata.lance\").to_table()\n",
    "    # Null prompts are NOT removed: the filtering variant below ran out of memory (OOM).\n",
    "    # tbl = db.create_table(\"diffusiondb\", data.filter(~pc.field(\"prompt\").is_null()), mode=\"overwrite\") # OOM\n",
    "    tbl = db.create_table(\"diffusiondb\", data, mode=\"overwrite\")\n",
    "    tbl.create_fts_index([\"prompt\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7e4fc03",
   "metadata": {
    "id": "d7e4fc03"
   },
   "source": [
    "## Create CLIP embedding function for the text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f8331d87",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 439,
     "referenced_widgets": [
      "ce95ce4a0e5346debb0a5f9202214e51",
      "029f3597c525493893877e3ed4105951",
      "31fc28c33c2a43d6a9f3edfd687a06aa",
      "1952bb6499d247018b3bb8987f4c09bd",
      "fb28818f86144e4b96bffa9f78f22d0e",
      "9cebd2e362d342a8bce9b28f46bd35e2",
      "0e317fbb6dd8470992e912e419a28185",
      "a913f319f6e246039b11500884328720",
      "fdfc01b44f0841dabd221d68752d3d72",
      "ab8d874c0a184ca2ba81114e3eaf7409",
      "d8b428b1b0464413b4fe5a9d4f111c57",
      "daacb1e1a2814655a7d09f65ff8452aa",
      "cf61f81503ae4dc08fa09394f3f5cccd",
      "ebeadf1ec0c1418499ddbde06f4996f7",
      "d005ee7956a44436957aa90c3300424c",
      "f95919ae1e9c49bf86b4a5af88e0c2da",
      "46ad1889dcac48cdbebad4f05ef09e2b",
      "cd263d8c5fd84a35a91d874699f5d51a",
      "304da0a1d41b40ac972c8c290c643870",
      "40cac08994364eb09a9cc2f9f0bde825",
      "aa75ab10ee4f4f85bc83221d7cfd36b6",
      "ac69897159f84312b11add482fcf6d07",
      "b3ea9b5595e1427a802318a98b514203",
      "a8e504801bf54ffb87d250e9321d6aa2",
      "4cb8ea17edd645e1aa275602a4abf8e7",
      "9d0cad1fdd564482bf011f7fef778d4b",
      "25b23989c05e483bbe56cb6605785eb1",
      "5eea8e0e481d4f34ae1366bd38ee4434",
      "6ad1a3c9c73341c5a2390ca9dd62668d",
      "1d189babe20742409805cb9fa555beae",
      "76e8ff4bd1ac452ba8251dc1f3228dfd",
      "0832ef0b231c4a078603ce3f43781d5f",
      "dc9a73097e614bc18aa09ed968050a62",
      "772c4c47a58543a8a1a178f40597920f",
      "455e6832a03a4c1aa5aab48a9109d91b",
      "0bc11bbd6038444cbef4721bfebbd531",
      "6d853fdc57cd4f35b4d0b6385177e8b7",
      "3c2a47c25bd142adaa5fde9e90d66151",
      "a8b9d35c4e424ba9b8a0c6518ad6a1b4",
      "7a92918e31854bf4bb6a3e595b7b1a07",
      "f8e42c0f4f8c4a6d910387c58676f660",
      "78138fe07e094f54a8f372ac3e7dd96f",
      "f1d094606e374cf1a86c8855c421ce01",
      "6f3b8a6c0f9a4f33ae3b14d34fc2d4ac",
      "d9e17ca730e34e44b78bd6a3a5307c7a",
      "73e2ca308ae441e68c632dae241da9b0",
      "124a9c8793f340fab919e90bcdcfd208",
      "776cf021de614b668bdecf6c2f22bc07",
      "5cabe7eba19f4fe5a6ebb0a8e019130f",
      "320d8f855cdf4223a7efa9c11a9c7d4b",
      "98a3539fa99c44f09444a05f65a88a33",
      "08784f5ae8d0488980e188dceaf85172",
      "e08c37c9b5cf4431b8ac3275f3884ddc",
      "428c0c75d4d94234bd6eb0515e602dba",
      "5ed4f7ff855440b28d8c0cb443260022",
      "8fe30c2025df48bd825ea00bbf623326",
      "6f37e35f95a14b5ea1046e672c491c52",
      "95923a6c40944ac1b8cf917d8572478a",
      "10107106ec1e4fd4b91736b7a9c985c5",
      "20d453f0fd2c44bd81a1883bc7584074",
      "085bcef04fa94ecb8fcf6df74602c8ca",
      "ef32f6da77214a1caa24962a37e706ed",
      "3f4e3aa5b1f34b518711148c3e1b1ec7",
      "67fdff76b75c4703a483e8cbe0d5ea55",
      "1846b4ef6dd94ac29d2ddf90e5a1ce6d",
      "29ae5071deec41859a57e5cf6cb37ad5",
      "eb38b4cae0e44591acfcedf07cc134cb",
      "55d0ca31f24b48ef911ae02fb089c1ad",
      "caa0ac3e4c664879a0cc0f6bd2663fa7",
      "7e08e9d81ad24353886a5f26ac4d1f1c",
      "f8250d657f5240efab2ba1725826ee72",
      "5ed78c44b67f435e90b9abca36b86cb8",
      "6b6b143e4e5045d6a5710886b3876b1c",
      "ce033a7365e640c58f7835007655460b",
      "1a4efbd30ac34718b8977845a66369f4",
      "87445801dca54689ae4dbca714b4849e",
      "19878e3fe3e3479e8bfbdd3f52a12d33",
      "26904ea599a140f0b7ab4e74b2d1f608",
      "65c145203efd4d7c9a3c1b8eb3a6d306",
      "cd67afa173284b84a3d617b07d95f9d0",
      "df2f32a0509b41f99e7f0c0fcb08ff35",
      "a32279b40489404b842122ad4168a1a7",
      "20943327f8ad4e818c22543ecc45a482",
      "b6538a0a3b1043adb5cb5c7f07a57938",
      "7cac0fb25dba43c38d609b7fe47bb5da",
      "3bda882470684f8ea32ac008d86a7c58",
      "1338c410755b41688a101be17dcb7a84",
      "59d983edf3a641b1b442ed17128ca243"
     ]
    },
    "id": "f8331d87",
    "outputId": "8622151e-5547-4eec-df17-baf7c1d98c03"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n",
      "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
      "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
      "You will be able to reuse this secret in all of your notebooks.\n",
      "Please note that authentication is recommended but still optional to access public models or datasets.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce95ce4a0e5346debb0a5f9202214e51",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/568 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "daacb1e1a2814655a7d09f65ff8452aa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b3ea9b5595e1427a802318a98b514203",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "772c4c47a58543a8a1a178f40597920f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d9e17ca730e34e44b78bd6a3a5307c7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8fe30c2025df48bd825ea00bbf623326",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eb38b4cae0e44591acfcedf07cc134cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "26904ea599a140f0b7ab4e74b2d1f608",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast\n",
    "\n",
    "# Hugging Face checkpoint shared by the tokenizer, model and processor below.\n",
    "MODEL_ID = \"openai/clip-vit-base-patch32\"\n",
    "\n",
    "tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)\n",
    "model = CLIPModel.from_pretrained(MODEL_ID)\n",
    "# NOTE(review): `processor` is not used by any visible cell - confirm before removing.\n",
    "processor = CLIPProcessor.from_pretrained(MODEL_ID)\n",
    "\n",
    "\n",
    "def embed_func(query):\n",
    "    \"\"\"Embed a text query with the CLIP text encoder.\n",
    "\n",
    "    Tokenizes `query` as a batch of one, passes it to\n",
    "    `model.get_text_features`, and returns the first (only) row of the\n",
    "    result as a NumPy array.\n",
    "    \"\"\"\n",
    "    # truncation=True clips over-long queries to the tokenizer's maximum\n",
    "    # length instead of letting the text encoder fail on them.\n",
    "    inputs = tokenizer([query], padding=True, truncation=True, return_tensors=\"pt\")\n",
    "    text_features = model.get_text_features(**inputs)\n",
    "    return text_features.detach().numpy()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "82c50eaf",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "82c50eaf",
    "outputId": "027a5d0f-8035-4642-e281-ae0771b9220f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "prompt: string\n",
       "seed: uint32\n",
       "step: uint16\n",
       "cfg: float\n",
       "sampler: string\n",
       "width: uint16\n",
       "height: uint16\n",
       "timestamp: timestamp[s]\n",
       "image_nsfw: float\n",
       "prompt_nsfw: float\n",
       "vector: fixed_size_list<item: float>[512]\n",
       "  child 0, item: float\n",
       "image: binary"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tbl.schema\n",
    "# Only the schema is shown: tbl.to_pandas().head() is disabled because it ran out of memory (OOM)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e4d7a54",
   "metadata": {
    "id": "5e4d7a54"
   },
   "source": [
    "\n",
    "## Search functions for Gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "10b8de6d",
   "metadata": {
    "id": "10b8de6d"
   },
   "outputs": [],
   "source": [
    "def find_image_vectors(query):\n",
    "    \"\"\"Search the table by CLIP embedding of a free-text query.\n",
    "\n",
    "    Embeds `query` with `embed_func`, searches the module-level `tbl` with\n",
    "    that vector (limit 9), and returns `(gallery_items, code)`:\n",
    "    `gallery_items` is the `_extract` output and `code` is a Python snippet,\n",
    "    shown in the UI, that mirrors the search that was run.\n",
    "    \"\"\"\n",
    "    emb = embed_func(query)\n",
    "    # `!r` quotes and escapes the query, so the displayed snippet stays valid\n",
    "    # Python even when the query itself contains quote characters.\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
    "        \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
    "        f\"embedding = embed_func({query!r})\\n\"\n",
    "        \"tbl.search(embedding).limit(9).to_pandas()\"\n",
    "    )\n",
    "    return (_extract(tbl.search(emb).limit(9).to_pandas()), code)\n",
    "\n",
    "\n",
    "def find_image_keywords(query):\n",
    "    \"\"\"Search the table with the raw query string.\n",
    "\n",
    "    Passes `query` unchanged to `tbl.search` (limit 9) and returns\n",
    "    `(gallery_items, code)`: `gallery_items` is the `_extract` output and\n",
    "    `code` is a Python snippet, shown in the UI, that mirrors the search.\n",
    "    NOTE(review): a string query presumably uses the full-text-search index\n",
    "    built on `prompt` - confirm against the installed lancedb version.\n",
    "    \"\"\"\n",
    "    # `!r` quotes and escapes the query, so the displayed snippet stays valid\n",
    "    # Python even when the query itself contains quote characters.\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
    "        \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
    "        f\"tbl.search({query!r}).limit(9).to_pandas()\"\n",
    "    )\n",
    "    return (_extract(tbl.search(query).limit(9).to_pandas()), code)\n",
    "\n",
    "\n",
    "def find_image_sql(query):\n",
    "    \"\"\"Run a SQL query against the table through DuckDB.\n",
    "\n",
    "    Executes `query` with `duckdb.sql` and returns `(gallery_items, code)`:\n",
    "    `gallery_items` is the `_extract` output and `code` is a Python snippet,\n",
    "    shown in the UI, that mirrors the query that was run. The result must\n",
    "    contain the `image` and `prompt` columns that `_extract` reads.\n",
    "\n",
    "    NOTE(review): `query` is executed verbatim. The UI is launched with\n",
    "    share=True, so anyone holding the public link can run arbitrary SQL.\n",
    "    \"\"\"\n",
    "    # `!r` quotes and escapes the query, so the displayed snippet stays valid\n",
    "    # Python even when the SQL contains string literals in single quotes.\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"import duckdb\\n\"\n",
    "        \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
    "        \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
    "        \"diffusiondb = tbl.to_lance()\\n\"\n",
    "        f\"duckdb.sql({query!r}).to_df()\"\n",
    "    )\n",
    "    # Looks unused but is required: DuckDB resolves the `diffusiondb` table\n",
    "    # name in the SQL from this Python variable (replacement scan).\n",
    "    diffusiondb = tbl.to_lance()  # noqa: F841\n",
    "    return (_extract(duckdb.sql(query).to_df()), code)\n",
    "\n",
    "\n",
    "def _extract(df):\n",
    "    \"\"\"Turn a result DataFrame into `(image, caption)` pairs for the gallery.\n",
    "\n",
    "    Reads the `image` column (bytes opened with PIL) and the `prompt` column\n",
    "    of `df` and returns a list of `(PIL image, prompt)` tuples, one per row.\n",
    "    A missing column raises `KeyError`.\n",
    "    \"\"\"\n",
    "    # A bare `import PIL` does not load the `PIL.Image` submodule, so import\n",
    "    # it explicitly instead of relying on another library having done so.\n",
    "    import PIL.Image\n",
    "\n",
    "    image_col = \"image\"\n",
    "    # Walk the two columns in lockstep; no per-row Series is needed.\n",
    "    return [\n",
    "        (PIL.Image.open(io.BytesIO(image_bytes)), prompt)\n",
    "        for image_bytes, prompt in zip(df[image_col], df[\"prompt\"])\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61aaf19b",
   "metadata": {
    "id": "61aaf19b"
   },
   "source": [
    "## Setup Gradio interface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b6f40300",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 630
    },
    "id": "b6f40300",
    "outputId": "1c9a7f57-a32d-42a7-a61f-b2fc1008c3b3"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-11-b0f437326cca>:21: GradioDeprecationWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n",
      "  gallery = gr.Gallery(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
      "Running on public URL: https://94fe48d6801f7e6c4d.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://94fe48d6801f7e6c4d.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio as gr\n",
    "\n",
    "\n",
    "# Three search tabs (vector, keyword, SQL) share one code panel and one gallery.\n",
    "with gr.Blocks() as demo:\n",
    "    with gr.Row():\n",
    "        with gr.Tab(\"Embeddings\"):\n",
    "            vector_query = gr.Textbox(value=\"portraits of a person\", show_label=False)\n",
    "            b1 = gr.Button(\"Submit\")\n",
    "        with gr.Tab(\"Keywords\"):\n",
    "            keyword_query = gr.Textbox(value=\"ninja turtle\", show_label=False)\n",
    "            b2 = gr.Button(\"Submit\")\n",
    "        with gr.Tab(\"SQL\"):\n",
    "            sql_query = gr.Textbox(\n",
    "                value=\"SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9\",\n",
    "                show_label=False,\n",
    "            )\n",
    "            b3 = gr.Button(\"Submit\")\n",
    "    with gr.Row():\n",
    "        code = gr.Code(label=\"Code\", language=\"python\")\n",
    "    with gr.Row():\n",
    "        # Layout options are passed to the constructor: `Gallery.style()` is\n",
    "        # deprecated and emits a GradioDeprecationWarning.\n",
    "        gallery = gr.Gallery(\n",
    "            label=\"Found images\",\n",
    "            show_label=False,\n",
    "            elem_id=\"gallery\",\n",
    "            columns=[3],\n",
    "            rows=[3],\n",
    "            object_fit=\"contain\",\n",
    "            height=\"auto\",\n",
    "        )\n",
    "\n",
    "    # Each handler returns (gallery items, code snippet), matching `outputs`.\n",
    "    b1.click(find_image_vectors, inputs=vector_query, outputs=[gallery, code])\n",
    "    b2.click(find_image_keywords, inputs=keyword_query, outputs=[gallery, code])\n",
    "    b3.click(find_image_sql, inputs=sql_query, outputs=[gallery, code])\n",
    "\n",
    "# share=True publishes a temporary public URL for the app.\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pEdiGi9EW0ZO",
   "metadata": {
    "id": "pEdiGi9EW0ZO"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  },
  "vscode": {
   "interpreter": {
    "hash": "511a7c77cb034b09af5465c01316a0f4bb20176d139e60e6d7915f9a637a5037"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
